from typing import Dict, Sequence, Optional

import os
import sys

sys.path.extend([".", ".."])

import time
import numpy as np
import torch
from torch import optim
import torch.multiprocessing as mp

from tqdm import tqdm
import itertools
from copy import deepcopy
from omegaconf import OmegaConf, ListConfig, DictConfig
import pandas as pd


from neural_fields.nf_utils import ACTS
from neural_fields.data import CycloneNFDataset, CycloneNFDataLoader
from neural_fields.nf_train import train_nf
from neural_fields.models.mlp import MLPNF
from neural_fields.models.siren import SIREN
from neural_fields.models.wire import WIRE


# Module-wide switch forwarded to the dataset (`flux_fields=`) and to the
# training routine (`use_flux_fields=`) by `run`.
FLUX_FIELDS = False
# Named presets looked up with `cfg.ky_filter` and passed to the dataset as
# `separate_ky_modes`. Entries are either single indices or nested lists of
# indices; presumably a nested list groups several ky modes together — confirm
# against the dataset implementation.
KY_MODES = {
    "base": None,
    "zfout": [0],
    "first2": [0, 1],
    # NOTE(review): this preset is named "first5" but lists six indices (0-5);
    # confirm whether index 5 is intended.
    "first5": [0, 1, 2, 3, 4, 5],
    "fancy1": [0, 1, 2, [3, 4, 5]],
    "fancy2": [0, 1, 2, [3, 4], [5, 6, 7, 8]],
}


def run(
    cfg: DictConfig,
    trajectory: str,
    timestep: int,
    is_grid: bool = False,
    verbose: bool = True,
    shared_init: Optional[str] = None,
):
    """Fit a neural field to one timestep of ``trajectory`` and optionally save it.

    Training happens in two stages:

    1. "Normal" training with the field loss only (``cfg.epochs`` epochs).
    2. Optional fine-tuning of the best stage-1 model with the integral /
       physical losses selected by ``cfg.physical_losses`` (``cfg.int_epochs``
       epochs). The stage is skipped when ``cfg.int_epochs`` is 0 or when
       ``cfg.physical_losses`` is present but empty.

    Args:
        cfg: Run configuration. Reads ``name``, ``normalization``,
            ``embed_type``, ``ky_filter``, ``batch_size``, ``n_layers``,
            ``dim``, ``skips``, ``lr``, ``epochs``, ``int_epochs``, the
            model-specific keys (``first_w0``, ``hidden_w0``, ``act_fn``),
            ``ckp_path`` when checkpoints are written, and the optional keys
            ``physical_losses`` and ``use_conflictfree``.
        trajectory: Trajectory identifier handed to ``CycloneNFDataset``; also
            used (minus ``_ifft`` / ``_realpotens`` and extension) in
            checkpoint file names.
        timestep: Timestep to fit. A sequence is also accepted, in which case
            only its first element is used.
        is_grid: When True no checkpoints are written.
        verbose: Print the parameter count / compression ratio and let the
            training routine print its progress.
        shared_init: Optional path to a state dict loaded into the freshly
            constructed model before training.

    Returns:
        ``(losses_pre, losses_fine)``: the loss mappings returned by the two
        training stages, each extended with a ``"CR"`` entry holding the
        compression ratio. ``losses_fine`` only contains ``"CR"`` when the
        fine-tuning stage was skipped.

    Raises:
        ValueError: If ``cfg.name`` is not ``"siren"``, ``"wire"`` or ``"mlp"``.
    """
    if isinstance(timestep, Sequence):
        timestep = timestep[0]
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
    data = CycloneNFDataset(
        trajectory,
        timesteps=timestep,
        normalize=cfg.normalization,
        normalize_coords="discrete" not in cfg.embed_type,
        separate_ky_modes=KY_MODES[cfg.ky_filter],
        flux_fields=FLUX_FIELDS,
        realpotens=True,
    )
    loader = CycloneNFDataLoader(data, cfg.batch_size, preload=True, shuffle=True)

    # The branches are mutually exclusive; an unknown name must fail loudly
    # here instead of surfacing later as an unbound `model`.
    if cfg.name == "siren":
        model = SIREN(
            data.ndim,
            data.nchannels,
            n_layers=cfg.n_layers,
            dim=cfg.dim,
            first_w0=cfg.first_w0,
            hidden_w0=cfg.hidden_w0,
            readout_w0=cfg.hidden_w0,
            skips=cfg.skips,
            embed_type=cfg.embed_type,
            clip_out=False,
        )
    elif cfg.name == "wire":
        model = WIRE(
            data.ndim,
            data.nchannels // 2,
            n_layers=cfg.n_layers,
            dim=cfg.dim,
            first_w0=cfg.first_w0,
            hidden_w0=cfg.hidden_w0,
            readout_w0=cfg.hidden_w0,
            complex_out=False,
            skips=cfg.skips,
            learnable_w0_s0=True,
        )
    elif cfg.name == "mlp":
        model = MLPNF(
            data.ndim,
            data.nchannels,
            n_layers=cfg.n_layers,
            dim=cfg.dim,
            act_fn=ACTS[cfg.act_fn],
            use_checkpoint=False,
            skips=cfg.skips,
            embed_type=cfg.embed_type,
        )
    else:
        raise ValueError(
            f"Unknown model name {cfg.name!r}; expected 'siren', 'wire' or 'mlp'"
        )

    if shared_init is not None:
        # Load onto CPU first so the checkpoint does not allocate memory on the
        # GPU it was saved from; load_state_dict copies into the model's tensors.
        model.load_state_dict(torch.load(shared_init, map_location="cpu"))

    n_params = sum(p.numel() for p in model.parameters())
    compression = data.full_df.numel() / n_params
    if verbose:
        print(f"Params: {n_params / 1e3:.2f}k, compression: {compression:.2f}x")

    opt = optim.AdamW(model.parameters(), cfg.lr, weight_decay=1e-8)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, cfg.epochs, 1e-8)

    # Normal training
    model, best_model, losses_pre = train_nf(
        model,
        n_epochs=cfg.epochs,
        data=data,
        loader=loader,
        device=device,
        field_subsamples=np.linspace(1.0, 0.2, cfg.epochs),
        use_flux_fields=FLUX_FIELDS,
        field_loss=True,
        physical_loss=False,
        optim=opt,
        sched=sched,
        use_tqdm=False,
        use_print=verbose,
    )
    # Snapshot the stage-1 results: fine-tuning below rebinds `model` and
    # `best_model` (and may modify the modules they refer to).
    model_pre = deepcopy(model)
    best_model_pre = deepcopy(best_model)

    # finetune
    int_epochs = cfg.int_epochs
    integral_loss_weight = {"flux": 1.0, "phi": 1.0}
    physical_loss_weight = {
        "kyspec": 1.0,
        "qspec": 1.0,
        "kyspec monotonicity": 1.0,
        "qspec monotonicity": 1.0,
    }
    if hasattr(cfg, "physical_losses"):
        # An empty selection disables fine-tuning altogether.
        if len(cfg.physical_losses) == 0:
            int_epochs = 0
        # Drop the weights of every loss group that was not requested.
        if "int" not in cfg.physical_losses:
            del integral_loss_weight["flux"]
            del integral_loss_weight["phi"]
        if "diag" not in cfg.physical_losses:
            del physical_loss_weight["kyspec"]
            del physical_loss_weight["qspec"]
        if "mono" not in cfg.physical_losses:
            del physical_loss_weight["kyspec monotonicity"]
            del physical_loss_weight["qspec monotonicity"]

    losses_fine = {}
    if int_epochs > 0:
        aux_sched = None
        use_conflictfree = (
            cfg.use_conflictfree if hasattr(cfg, "use_conflictfree") else "none"
        )
        # The optimizer has to own the parameters of the module that is
        # actually fine-tuned (`best_model`). If `train_nf` returned the same
        # object for `model` and `best_model` this is equivalent to the
        # previous behaviour; if they are distinct objects, an optimizer built
        # over `model` would never update the fine-tuned network.
        aux_opt = optim.AdamW(best_model.parameters(), cfg.lr / 100, weight_decay=1e-8)
        # aux_sched = optim.lr_scheduler.CosineAnnealingLR(aux_opt, cfg.epochs, 1e-12)
        model, best_model, losses_fine = train_nf(
            best_model,
            n_epochs=int_epochs,
            data=data,
            loader=loader,
            device=device,
            use_flux_fields=FLUX_FIELDS,
            field_loss=False,
            physical_loss=True,
            aux_optim=aux_opt,
            aux_sched=aux_sched,
            use_tqdm=False,
            use_print=verbose,
            integral_loss_weight=integral_loss_weight,
            physical_loss_weight=physical_loss_weight,
            use_conflictfree=use_conflictfree,
        )

    if not is_grid:
        os.makedirs(cfg.ckp_path, exist_ok=True)
        fname = trajectory.replace("_ifft", "").replace("_realpotens", "")
        fname = fname.split(".")[0]
        model_name = f"{cfg.name.lower()}_{fname}_t{timestep}_x{int(compression)}"
        torch.save(
            {"state_dict": model_pre.state_dict(), "cfg": cfg},
            f"{cfg.ckp_path}/{model_name}.pt",
        )
        torch.save(
            {"state_dict": best_model_pre.state_dict(), "cfg": cfg},
            f"{cfg.ckp_path}/best_{model_name}.pt",
        )
        # Gate on the effective epoch count: when fine-tuning was skipped there
        # is no fine-tuned model, and saving the stage-1 weights under an
        # "int_" name would be misleading.
        if int_epochs > 0:
            torch.save(
                {"state_dict": model.state_dict(), "cfg": cfg},
                f"{cfg.ckp_path}/int_{model_name}.pt",
            )
            torch.save(
                {"state_dict": best_model.state_dict(), "cfg": cfg},
                f"{cfg.ckp_path}/best_int_{model_name}.pt",
            )
    losses_pre["CR"] = compression
    losses_fine["CR"] = compression
    return losses_pre, losses_fine


def grid_worker(
    combo_cfg: DictConfig,
    trajectories: Sequence[str],
    timesteps: Sequence[int],
    gpu: int,
    return_dict: Dict,
    key: int,
):
    """Evaluate one grid-search configuration and publish its mean metrics.

    Runs ``run`` (without checkpointing, silently) for every combination of
    trajectory and timestep on CUDA device ``gpu``, sums each reported metric,
    and stores the per-metric averages in ``return_dict[key]``. Metrics of the
    first training stage are prefixed with ``pre_``, those of the fine-tuning
    stage with ``fine_``. Metric values are converted with ``float()``.

    Args:
        combo_cfg: Configuration for this grid point.
        trajectories: Trajectories to evaluate.
        timesteps: Timesteps to evaluate for each trajectory.
        gpu: CUDA device index selected for this process.
        return_dict: Mapping that receives the averaged metrics.
        key: Key under which the averages are stored in ``return_dict``.
    """
    torch.cuda.set_device(int(gpu))

    def accumulate(totals, metrics):
        # Add every metric of a single run onto the running totals.
        for name, value in metrics.items():
            totals[name] = totals.get(name, 0.0) + float(value)

    pre_totals = {}
    fine_totals = {}

    for traj in trajectories:
        for timestep in timesteps:
            stage_pre, stage_fine = run(
                combo_cfg, traj, timestep, is_grid=True, verbose=False
            )
            accumulate(pre_totals, stage_pre)
            if stage_fine is not None:
                accumulate(fine_totals, stage_fine)

    # Every total is divided by the number of (trajectory, timestep) runs.
    n_runs = len(trajectories) * len(timesteps)
    summary = {}
    for prefix, totals in (("pre", pre_totals), ("fine", fine_totals)):
        for name, total in totals.items():
            summary[f"{prefix}_{name}"] = total / n_runs

    return_dict[key] = summary


def grid(cfg: DictConfig):
    """Run a grid search over every list-valued entry of ``cfg``.

    Each ``ListConfig`` value in ``cfg`` (except ``timesteps``, ``trajectory``
    and ``gpus``) is treated as a search axis; the Cartesian product of all
    axes defines the jobs. Every job runs ``grid_worker`` in its own spawned
    process, with at most ``len(cfg.gpus) * cfg.throttling`` processes alive
    at once and GPUs assigned round-robin by job index. The averaged metrics
    of all finished jobs are printed and written to
    ``grid_search_{cfg.name}{tag}.csv``.

    Args:
        cfg: Full configuration. Reads ``gpus``, ``throttling``,
            ``trajectory``, ``name`` and either ``timesteps`` or
            ``timeframe`` / ``coarse``; optionally ``physical_losses``.

    Raises:
        ValueError: If ``len(cfg.gpus) * cfg.throttling`` is smaller than 1
            (no process could ever be scheduled).
    """
    # Grid search parameters
    reserved = ("timesteps", "trajectory", "gpus")
    grid_params = {
        k: v
        for k, v in cfg.items()
        if isinstance(v, ListConfig) and k not in reserved
    }
    fixed_params = {k: v for k, v in cfg.items() if k not in grid_params}

    param_names = list(grid_params.keys())
    combinations = list(itertools.product(*grid_params.values()))

    if hasattr(cfg, "timesteps"):
        timesteps = cfg.timesteps
    else:
        timesteps = list(range(100, 100 + cfg.timeframe * cfg.coarse, cfg.coarse))

    max_active = len(cfg.gpus) * cfg.throttling
    if max_active < 1:
        # With a limit of 0 the throttle loop below would never terminate.
        raise ValueError("cfg.gpus must be non-empty and cfg.throttling must be >= 1")

    # parallelize over grid search
    ctx = mp.get_context("spawn")
    # The manager is created from the same context as the workers and is shut
    # down when the block exits, so results are copied out before leaving it.
    with ctx.Manager() as manager, tqdm(
        total=len(combinations), desc="Grid search"
    ) as pbar:
        return_dict = manager.dict()
        active_processes = []

        for job_id, combo in enumerate(combinations):
            combo_cfg = deepcopy(fixed_params)
            combo_cfg.update(dict(zip(param_names, combo)))
            combo_cfg = OmegaConf.create(combo_cfg)

            while len(active_processes) >= max_active:
                # Query is_alive() exactly once per process so that a process
                # finishing mid-scan is neither lost nor left unjoined.
                still_running = []
                for p in active_processes:
                    if p.is_alive():
                        still_running.append(p)
                    else:
                        p.join()
                        pbar.update()
                active_processes = still_running
                if len(active_processes) >= max_active:
                    time.sleep(1.0)

            gpu = cfg.gpus[job_id % len(cfg.gpus)]
            p = ctx.Process(
                target=grid_worker,
                args=(
                    combo_cfg,
                    cfg.trajectory,
                    timesteps,
                    gpu,
                    return_dict,
                    job_id,
                ),
            )
            p.start()
            active_processes.append(p)

        for p in active_processes:
            p.join()
            pbar.update()

        # Plain-dict snapshot; the proxy becomes unusable after shutdown.
        finished = return_dict.copy()

    # aggregate (the progress bar counts finished jobs only, so it is not
    # advanced again here)
    results = []
    missing = []
    for combo_id, combo in enumerate(combinations):
        if combo_id in finished:
            results.append({**dict(zip(param_names, combo)), **finished[combo_id]})
        else:
            missing.append(combo_id)
    if missing:
        # A job without a result most likely crashed; report it instead of
        # silently dropping the row.
        print(f"Warning: {len(missing)} grid job(s) returned no metrics: {missing}")

    grid_df = pd.DataFrame(results)
    print(grid_df)
    tag = ""
    if hasattr(cfg, "physical_losses") and len(cfg.physical_losses) > 1:
        tag = "_pinn"
    grid_df.to_csv(f"grid_search_{cfg.name}{tag}.csv", index=False)


def worker(cfg: DictConfig, traj: str, timesteps: Sequence, gpu: int):
    """Fit and checkpoint one model per timestep of ``traj`` on device ``gpu``.

    When ``cfg.use_shared_init`` is set and truthy, every model starts from the
    state dict at ``nf_shared_init/<traj without '.h5'>.pth``. Each timestep is
    passed to ``run`` as a one-element list, with checkpointing enabled and
    printing disabled; the returned losses are discarded.
    """
    torch.cuda.set_device(int(gpu))

    # The initialisation path depends only on the trajectory, not the timestep.
    init_path = None
    if getattr(cfg, "use_shared_init", False):
        init_path = f"nf_shared_init/{traj.replace('.h5', '')}.pth"

    for step in timesteps:
        run(
            cfg,
            traj,
            [int(step)],
            is_grid=False,
            verbose=False,
            shared_init=init_path,
        )


def main(cfg: DictConfig):
    """Fit models for every trajectory / timestep using parallel worker processes.

    For each trajectory the timesteps are split into ``cfg.throttling`` chunks
    (``np.array_split``); every non-empty chunk is handled by one spawned
    ``worker`` process. At most ``len(cfg.gpus) * cfg.throttling`` processes
    are alive at once and GPUs are assigned round-robin by job index. The
    progress bar counts timesteps, so it is advanced by the chunk size of each
    finished process.

    Args:
        cfg: Full configuration. Reads ``gpus``, ``throttling``,
            ``trajectory`` and either ``timesteps`` or ``timeframe`` /
            ``coarse``.

    Raises:
        ValueError: If ``len(cfg.gpus) * cfg.throttling`` is smaller than 1
            (no process could ever be scheduled).
    """
    if hasattr(cfg, "timesteps"):
        timesteps = cfg.timesteps
    else:
        timesteps = list(range(100, 100 + cfg.timeframe * cfg.coarse, cfg.coarse))

    max_active = len(cfg.gpus) * cfg.throttling
    if max_active < 1:
        # With a limit of 0 the throttle loop below would never terminate.
        raise ValueError("cfg.gpus must be non-empty and cfg.throttling must be >= 1")

    ctx = mp.get_context("spawn")
    # Entries are (process, number of timesteps handled by that process).
    active_processes = []
    n_failed = 0

    total_jobs = len(cfg.trajectory) * len(timesteps)
    pbar = tqdm(total=total_jobs, desc="Parallel evaluation")

    job_id = 0
    for traj in cfg.trajectory:
        for timestep_chunk in np.array_split(timesteps, cfg.throttling):
            n_steps = len(timestep_chunk)
            if n_steps == 0:
                # More chunks than timesteps: nothing to do for this chunk.
                continue

            while len(active_processes) >= max_active:
                # Query is_alive() exactly once per process so that a process
                # finishing mid-scan is neither lost nor left unjoined.
                still_running = []
                for p, p_steps in active_processes:
                    if p.is_alive():
                        still_running.append((p, p_steps))
                        continue
                    p.join()
                    if p.exitcode != 0:
                        n_failed += 1
                    pbar.update(p_steps)
                active_processes = still_running
                if len(active_processes) >= max_active:
                    time.sleep(0.5)

            gpu = cfg.gpus[job_id % len(cfg.gpus)]
            p = ctx.Process(
                target=worker,
                args=(cfg, traj, timestep_chunk, gpu),
            )
            p.start()
            active_processes.append((p, n_steps))
            job_id += 1

    for p, p_steps in active_processes:
        p.join()
        if p.exitcode != 0:
            n_failed += 1
        pbar.update(p_steps)

    pbar.close()

    if n_failed:
        print(f"Warning: {n_failed} worker process(es) exited with a non-zero exit code")

    # NOTE: the compression report below is currently disabled.
    # raw_size = os.path.getsize(cyclone.files[0]) / len(cyclone) * len(timesteps)
    # ckpt_files = [
    #     f for f in os.listdir(cfg.ckp_path) if f.endswith(".pt") and cfg.trajectory in f
    # ]
    # ckpt_size = sum(os.path.getsize(os.path.join(cfg.ckp_path, f)) for f in ckpt_files)
    # ckpt_size = ckpt_size / 2

    # compression = raw_size / ckpt_size
    # saved = raw_size - ckpt_size

    # print(f"Trajectory {cfg.trajectory} compressed")
    # print(f"Raw size ({len(timesteps)} timesteps): {raw_size / 1e6:.2f}MB")
    # print(f"Models size: {ckpt_size / 1e6:.2f}MB")
    # print(f"Compression: {compression:.2f}x, saved: {saved / 1e6:.2f}MB")


if __name__ == "__main__":
    # Command-line overrides take precedence over the YAML file, whose path can
    # itself be given on the command line as `config=...`.
    overrides = OmegaConf.from_cli()
    config_path = overrides.get("config", "nf/eval.yaml")
    cfg = OmegaConf.merge(OmegaConf.load(config_path), overrides)

    # Echo the effective configuration between two separator lines.
    banner = "#" * 88
    print(banner)
    print(OmegaConf.to_yaml(cfg))
    print(banner)

    mode = cfg.mode
    if mode == "grid":
        grid(cfg)
    elif mode == "default":
        main(cfg)
    else:
        raise NotImplementedError(mode)
